import tensorflow as tf 

# config = tf.ConfigProto()
# config.gpu_options.allow_growth = True

# sess = tf.Session(config=config)
# NOTE(review): the session is created before the graph is built below;
# variables must still be initialized (tf.global_variables_initializer)
# before training — presumably done by the training driver; confirm.
sess = tf.Session()
from tensorflow.contrib import rnn
import numpy as np


# --- Hyperparameters ---
decay = 0.85               # learning-rate decay factor (unused here; presumably applied by the training loop — verify)
max_epoch = 5              # epochs at the initial learning rate (unused here)
max_max_epoch = 10         # total training epochs (unused here)
timestep_size = max_len = 32   # sequence length: timesteps per sample
vocab_size = 5159          # vocabulary size of the embedding table
input_size = embedding_size = 64   # embedding dimension
class_num = 5              # number of output tag classes
hidden_size = 128          # LSTM hidden units per direction
layer_num = 2              # stacked LSTM layers per direction
max_grad_norm = 5.0        # gradient clipping threshold (global norm)

# Scalars fed at run time via feed_dict.
lr = tf.placeholder(tf.float32)          # learning rate
keep_prob = tf.placeholder(tf.float32)   # dropout keep probability
batch_size = tf.placeholder(tf.int32)    # batch size (needed by zero_state)
# Path where model checkpoints are saved.
model_save_path = 'ckpt/bi-lstm.ckpt'

def weight_variable(shape):
    """Return a trainable weight Variable of `shape`, truncated-normal init (stddev 0.1)."""
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

def bias_variable(shape):
    """Return a trainable bias Variable of `shape`, initialized to the constant 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape))

# BUG FIX: these were declared tf.float32, but x_inputs holds token ids fed
# to tf.nn.embedding_lookup (requires integer indices) and y_inputs holds
# class labels fed to sparse_softmax_cross_entropy_with_logits and compared
# against an int32 argmax — both require integer dtypes.
x_inputs = tf.placeholder(tf.int32, [None, timestep_size], name='x_input')
y_inputs = tf.placeholder(tf.int32, [None, timestep_size], name='y_input')

def bi_lstm(x_inputs):
    """Build a stacked bidirectional LSTM tagger over token embeddings.

    Args:
        x_inputs: token-id tensor of shape [batch_size, timestep_size]
            (cast to int32 here, so a float placeholder still works).

    Returns:
        logits tensor of shape [batch_size * timestep_size, class_num],
        one row of class scores per timestep.
    """
    embedding = tf.get_variable("embedding", [vocab_size, embedding_size], dtype=tf.float32)
    # embedding_lookup requires integer indices; cast defensively.
    inputs = tf.nn.embedding_lookup(embedding, tf.cast(x_inputs, tf.int32))

    def make_cell():
        # One LSTM unit with dropout applied to its outputs.
        cell = rnn.BasicLSTMCell(hidden_size, forget_bias=1.0, state_is_tuple=True)
        return rnn.DropoutWrapper(cell=cell, input_keep_prob=1.0, output_keep_prob=keep_prob)

    # Stacked (multi-layer) LSTM per direction. BUG FIX: the original used
    # [cell] * layer_num, which shares ONE cell object across all layers and
    # raises a variable-reuse error in TF >= 1.1; build a fresh cell per layer.
    cell_fw = rnn.MultiRNNCell([make_cell() for _ in range(layer_num)], state_is_tuple=True)
    cell_bw = rnn.MultiRNNCell([make_cell() for _ in range(layer_num)], state_is_tuple=True)

    # Zero initial states (batch_size is a fed placeholder).
    initial_state_fw = cell_fw.zero_state(batch_size=batch_size, dtype=tf.float32)
    initial_state_bw = cell_bw.zero_state(batch_size=batch_size, dtype=tf.float32)

    # Manually unrolled bidirectional computation.
    with tf.variable_scope('bidirectional_rnn'):
        # Forward direction.
        outputs_fw = list()
        state_fw = initial_state_fw
        with tf.variable_scope('fw'):
            for timestep in range(timestep_size):
                if timestep > 0:
                    tf.get_variable_scope().reuse_variables()
                # BUG FIX: in the original, the cell call and append were
                # nested inside the `if`, so timestep 0 was skipped, and the
                # output LIST was overwritten by the cell's output tensor
                # before being appended to itself.
                (output_fw, state_fw) = cell_fw(inputs[:, timestep, :], state_fw)
                outputs_fw.append(output_fw)
        # Backward direction: feed the time-reversed sequence.
        outputs_bw = list()
        state_bw = initial_state_bw
        with tf.variable_scope('bw'):
            inputs_rev = tf.reverse(inputs, [1])
            for timestep in range(timestep_size):
                if timestep > 0:
                    tf.get_variable_scope().reuse_variables()
                (output_bw, state_bw) = cell_bw(inputs_rev[:, timestep, :], state_bw)
                outputs_bw.append(output_bw)
        # Flip the backward outputs back into forward time order.
        # outputs_bw: [timestep_size, batch_size, hidden_size]
        outputs_bw = tf.reverse(outputs_bw, [0])
        # Concatenate directions -> [timestep_size, batch_size, hidden_size*2].
        output = tf.concat([outputs_fw, outputs_bw], 2)
        # Reorder to [batch_size, timestep_size, hidden_size*2] so rows align
        # with y_inputs' [batch_size, timestep_size] layout, then flatten to
        # one feature row per timestep.
        output = tf.transpose(output, perm=[1, 0, 2])
        output = tf.reshape(output, [-1, hidden_size * 2])

    # Per-timestep softmax projection onto the tag classes.
    softmax_w = weight_variable([hidden_size * 2, class_num])
    softmax_b = bias_variable([class_num])
    logits = tf.matmul(output, softmax_w) + softmax_b
    return logits

y_pred = bi_lstm(x_inputs)

# Flatten labels to one id per timestep and cast to int32: tf.equal requires
# matching dtypes against the int32 argmax below, and
# sparse_softmax_cross_entropy_with_logits requires integer labels.
labels_flat = tf.cast(tf.reshape(y_inputs, [-1]), tf.int32)

# (name keeps the original 'corrent_' typo for backward compatibility)
corrent_prediction = tf.equal(tf.cast(tf.argmax(y_pred, 1), tf.int32), labels_flat)
accuracy = tf.reduce_mean(tf.cast(corrent_prediction, tf.float32))
cost = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels_flat, logits=y_pred))

# --- Optimization ---
# All trainable parameters of the model.
tvars = tf.trainable_variables()
# Gradients of the loss w.r.t. each parameter, clipped by global norm.
grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), max_grad_norm)
# Adam optimizer with an externally fed learning rate.
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
# Apply the clipped gradients. BUG FIX: the original called
# get_or_create_gllobal_step (typo), which raises AttributeError.
train_op = optimizer.apply_gradients(
    zip(grads, tvars),
    global_step=tf.contrib.framework.get_or_create_global_step())

print('Finished creating the bi-lstm model')